library(ISLR2)
library(mlr)
## Loading required package: ParamHelpers
library(mlr3)
## Registered S3 method overwritten by 'paradox':
## method from
## c.ParamSet ParamHelpers
##
## Attaching package: 'mlr3'
## The following objects are masked from 'package:mlr':
##
## benchmark, resample
library(mlr3learners)
library(mlr3verse)
library(mlr3tuning)
## Loading required package: paradox
library(iml)
This document demonstrates the use of the GADGET package to build interpretable trees based on local feature effects. We showcase both synthetic XOR data and a real-world Bikeshare dataset.
The GADGET (Generalized Additive Decomposition for Global Explanation Trees) package provides interpretable model explanation through regionally partitioned trees, built upon local feature effect estimates. In this notebook, we demonstrate how GADGET can be used to construct and visualize global explanations from local interpretation methods such as ICE (Individual Conditional Expectation) or PDP (Partial Dependence Plots).
The following core functions are provided to users:
- compute_tree(): Constructs a recursive explanation tree by optimizing an objective function (e.g., ICE-based or PDP-based L2 reduction) at each node. It supports both numeric and categorical features and allows flexible control over split granularity and minimal node sizes.
- extract_split_criteria(): Summarizes the full structure of the computed explanation tree, including split features, values, node depth, and objective improvement. It can also extract criteria related to a specific feature.
- plot_tree(): Visualizes the ICE or PDP behavior of each feature in the tree across split regions. Regional plots help users interpret how a model behaves differently across subgroups.
- plot_tree_structure(): Draws the structure of the explanation tree itself, with node annotations and split conditions.
This notebook walks through two examples:
1. A synthetic XOR-like dataset to illustrate behavior in controlled settings with known interactions.
2. The Bikeshare dataset from the ISLR2 package, demonstrating usage on real-world, heterogeneous data.
The GADGET package is especially useful when interpreting black-box models (e.g., neural networks, random forests) in terms of their localized behavior across feature space.
The synthetic data set is constructed to mimic an XOR-like interaction structure with noise. The response variable y is defined as:
\[ y = \begin{cases} +3x_1, & \text{if } x_3 > 0 \\ -3x_1, & \text{if } x_3 \leq 0 \end{cases} \;+\; x_3 + \varepsilon \]
where \(\varepsilon \sim \mathcal{N}(0, 0.3^2)\), and all covariates \(x_1, x_2, x_3 \sim \mathcal{U}(-1, 1)\) independently.
This setup creates a nonlinear response surface in which the slope of \(x_1\) flips direction depending on the sign of \(x_3\), making it suitable for evaluating interpretable model partitioning via ICE/PDP.
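The data-generating process above can be reproduced in a few lines of base R. This is a sketch: the seed and the object name syn.data are illustrative choices, while n = 500 matches the root-node observation count reported by the tree output later in this document.

```r
set.seed(123)                    # illustrative seed
n  <- 500                        # matches n.obs at the root node below
x1 <- runif(n, -1, 1)
x2 <- runif(n, -1, 1)            # noise feature, unrelated to y
x3 <- runif(n, -1, 1)
eps <- rnorm(n, mean = 0, sd = 0.3)
# slope of x1 flips sign with the sign of x3 (XOR-like interaction)
y  <- ifelse(x3 > 0, 3 * x1, -3 * x1) + x3 + eps
syn.data <- data.frame(x1, x2, x3, y)
```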
## [Tune] Started tuning learner regr.nnet for parameter set:
##           Type len Def                         Constr Req Tunable Trafo
## decay discrete   -   - 0.5,0.1,0.01,0.001,1e-04,1e-05   -    TRUE     -
## size  discrete   -   -                   3,5,10,20,30   -    TRUE     -
## With control class: TuneControlGrid
## Imputation value: Inf
## Imputation value: Inf
## Imputation value: Inf
## [Tune-x] 1: decay=0.5; size=3
## [Tune-y] 1: mse.test.mean=0.6342609,mae.test.mean=0.5851592,rsq.test.mean=0.8059670; time: 0.0 min
## [Tune-x] 2: decay=0.1; size=3
## [Tune-y] 2: mse.test.mean=0.5969903,mae.test.mean=0.5617087,rsq.test.mean=0.8171748; time: 0.0 min
## [Tune-x] 3: decay=0.01; size=3
## [Tune-y] 3: mse.test.mean=0.5760032,mae.test.mean=0.5523641,rsq.test.mean=0.8240000; time: 0.0 min
## [Tune-x] 4: decay=0.001; size=3
## [Tune-y] 4: mse.test.mean=0.7500930,mae.test.mean=0.6212394,rsq.test.mean=0.7718377; time: 0.0 min
## [Tune-x] 5: decay=1e-04; size=3
## [Tune-y] 5: mse.test.mean=0.9251352,mae.test.mean=0.7089577,rsq.test.mean=0.7146937; time: 0.0 min
## [Tune-x] 6: decay=1e-05; size=3
## [Tune-y] 6: mse.test.mean=0.8405647,mae.test.mean=0.6632766,rsq.test.mean=0.7438060; time: 0.0 min
## [Tune-x] 7: decay=0.5; size=5
## [Tune-y] 7: mse.test.mean=0.4992945,mae.test.mean=0.4933752,rsq.test.mean=0.8475109; time: 0.0 min
## [Tune-x] 8: decay=0.1; size=5
## [Tune-y] 8: mse.test.mean=0.3934074,mae.test.mean=0.4357591,rsq.test.mean=0.8795930; time: 0.0 min
## [Tune-x] 9: decay=0.01; size=5
## [Tune-y] 9: mse.test.mean=0.3613862,mae.test.mean=0.4275523,rsq.test.mean=0.8894339; time: 0.0 min
## [Tune-x] 10: decay=0.001; size=5
## [Tune-y] 10: mse.test.mean=0.4156572,mae.test.mean=0.4658912,rsq.test.mean=0.8735117; time: 0.0 min
## [Tune-x] 11: decay=1e-04; size=5
## [Tune-y] 11: mse.test.mean=0.4474042,mae.test.mean=0.4846311,rsq.test.mean=0.8629718; time: 0.0 min
## [Tune-x] 12: decay=1e-05; size=5
## [Tune-y] 12: mse.test.mean=0.3499822,mae.test.mean=0.4148599,rsq.test.mean=0.8936464; time: 0.0 min
## [Tune-x] 13: decay=0.5; size=10
## [Tune-y] 13: mse.test.mean=0.4890649,mae.test.mean=0.4868183,rsq.test.mean=0.8504408; time: 0.0 min
## [Tune-x] 14: decay=0.1; size=10
## [Tune-y] 14: mse.test.mean=0.2923974,mae.test.mean=0.3640902,rsq.test.mean=0.9106366; time: 0.0 min
## [Tune-x] 15: decay=0.01; size=10
## [Tune-y] 15: mse.test.mean=0.2609520,mae.test.mean=0.3297380,rsq.test.mean=0.9207877; time: 0.0 min
## [Tune-x] 16: decay=0.001; size=10
## [Tune-y] 16: mse.test.mean=0.2375898,mae.test.mean=0.3291642,rsq.test.mean=0.9282886; time: 0.0 min
## [Tune-x] 17: decay=1e-04; size=10
## [Tune-y] 17: mse.test.mean=2.3642889,mae.test.mean=0.4328764,rsq.test.mean=0.1642193; time: 0.0 min
## [Tune-x] 18: decay=1e-05; size=10
## [Tune-y] 18: mse.test.mean=0.2631678,mae.test.mean=0.3636896,rsq.test.mean=0.9187290; time: 0.0 min
## [Tune-x] 19: decay=0.5; size=20
## [Tune-y] 19: mse.test.mean=0.4809190,mae.test.mean=0.4794360,rsq.test.mean=0.8531026; time: 0.0 min
## [Tune-x] 20: decay=0.1; size=20
## [Tune-y] 20: mse.test.mean=0.2871772,mae.test.mean=0.3596285,rsq.test.mean=0.9122657; time: 0.0 min
## [Tune-x] 21: decay=0.01; size=20
## [Tune-y] 21: mse.test.mean=0.2891807,mae.test.mean=0.3558324,rsq.test.mean=0.9120576; time: 0.0 min
## [Tune-x] 22: decay=0.001; size=20
## [Tune-y] 22: mse.test.mean=0.2744659,mae.test.mean=0.3541896,rsq.test.mean=0.9164298; time: 0.0 min
## [Tune-x] 23: decay=1e-04; size=20
## [Tune-y] 23: mse.test.mean=0.3111040,mae.test.mean=0.3720471,rsq.test.mean=0.9044419; time: 0.0 min
## [Tune-x] 24: decay=1e-05; size=20
## [Tune-y] 24: mse.test.mean=0.2997305,mae.test.mean=0.3580383,rsq.test.mean=0.9084981; time: 0.0 min
## [Tune-x] 25: decay=0.5; size=30
## [Tune-y] 25: mse.test.mean=0.4767558,mae.test.mean=0.4750559,rsq.test.mean=0.8543361; time: 0.0 min
## [Tune-x] 26: decay=0.1; size=30
## [Tune-y] 26: mse.test.mean=0.2899194,mae.test.mean=0.3622760,rsq.test.mean=0.9114141; time: 0.0 min
## [Tune-x] 27: decay=0.01; size=30
## [Tune-y] 27: mse.test.mean=0.2902570,mae.test.mean=0.3551265,rsq.test.mean=0.9118208; time: 0.0 min
## [Tune-x] 28: decay=0.001; size=30
## [Tune-y] 28: mse.test.mean=0.3308723,mae.test.mean=0.3842306,rsq.test.mean=0.8982397; time: 0.0 min
## [Tune-x] 29: decay=1e-04; size=30
## [Tune-y] 29: mse.test.mean=0.4041151,mae.test.mean=0.3993078,rsq.test.mean=0.8759221; time: 0.0 min
## [Tune-x] 30: decay=1e-05; size=30
## [Tune-y] 30: mse.test.mean=0.3751666,mae.test.mean=0.4071508,rsq.test.mean=0.8864029; time: 0.0 min
## [Tune] Result: decay=0.001; size=10 : mse.test.mean=0.2375898,mae.test.mean=0.3291642,rsq.test.mean=0.9282886
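A tuning setup of the shape that produces the log above can be sketched with the mlr package. This is a hedged reconstruction: the task, learner options, and resampling scheme are assumptions (the log does not show them), but the parameter grid and the three measures match what is printed.

```r
library(mlr)
# assumes syn.data as simulated earlier in this notebook
task <- makeRegrTask(data = syn.data, target = "y")
lrn  <- makeLearner("regr.nnet", trace = FALSE)
# discrete grid matching the parameter set in the log
ps <- makeParamSet(
  makeDiscreteParam("decay", values = c(0.5, 0.1, 0.01, 0.001, 1e-04, 1e-05)),
  makeDiscreteParam("size",  values = c(3, 5, 10, 20, 30))
)
res <- tuneParams(lrn, task, resampling = cv5,
                  measures = list(mse, mae, rsq),
                  par.set = ps, control = makeTuneControlGrid())
res$x  # best hyperparameter combination (decay, size)
```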
We first use the iml package to compute
Individual Conditional Expectation (ICE) curves for
each feature based on the trained neural network model. These ICE curves
capture local prediction behavior.
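In iml, this amounts to wrapping the fitted model in a Predictor object and calling FeatureEffect with method = "ice". The sketch below assumes syn.model holds the tuned network and that the per-feature effects are collected in a list named syn.effect; both names are assumptions chosen to match the objects used later.

```r
library(iml)
# syn.model: the tuned regression model; syn.data: the simulated data
predictor <- Predictor$new(syn.model,
                           data = syn.data[, c("x1", "x2", "x3")],
                           y    = syn.data$y)
# one ICE object per feature; these local effects are the input to the tree
feats <- c("x1", "x2", "x3")
syn.effect <- lapply(feats, function(feat)
  FeatureEffect$new(predictor, feature = feat, method = "ice"))
names(syn.effect) <- feats
```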
Next, we apply build_tree() from the GADGET package to
construct an interpretable explanation tree. The tree partitions the
data into regions where the PDP (partial dependence) behavior is
relatively homogeneous, as measured by the "SS_L2_pd" objective
function.
The set Z specifies contextual features that may
influence the behavior of the features of interest. The resulting tree
identifies regions with distinct interaction patterns in model
predictions.
library(GADGET)
syn.tree <- build_tree(effect = syn.effect,
                       data = syn.data,
                       effect.method = "pd",
                       target.feature.name = "y",
                       split.feature = NULL,
                       n.split = 2,
                       impr.par = 0.1,
                       n.quantiles = NULL,
                       min.node.size = 1)
syn.plot <- plot_tree_pd(syn.tree, syn.effect, syn.data,
                         target.feature.name = "y",
                         show.plot = TRUE, show.point = TRUE,
                         mean.center = TRUE)
plot_2_1 <- plot_node_pd(syn.plot, depth = 2, node.idx = 1)
plot_tree_structure(syn.tree)
extract_split_info(syn.tree)
## depth id n.obs child.type split.feature split.value objective.value
## 1 1 1 500 root x3 -0.002799334 62335.6117
## 2 2 2 242 left none NA 262.1979
## 3 2 3 258 right none NA 347.7720
## intImp intImp.x1 intImp.x2 intImp.x3 split.feature.parent
## 1 0.9902147 0.9862745 0.05518114 0.9963916 <NA>
## 2 NA NA NA NA x3
## 3 NA NA NA NA x3
## split.value.parent objective.value.parent intImp_parent is.final
## 1 NA NA NA FALSE
## 2 -0.002799334 62335.61 0.9902147 TRUE
## 3 -0.002799334 62335.61 0.9902147 TRUE
We outline several planned improvements and extensions to the GADGET package along two main directions:
- An interactive interface combining shiny with ggparty, allowing users to interactively explore split trees and node-level plots.
- Better modularization and documentation of the internal splitting routines (find_best_binary_split, perform_split, generate_split_candidates), the tree-construction entry point (compute_tree()), the visualization functions (plot_tree(), plot_tree_structure()), and the general helpers (extract_split_criteria()).
These enhancements will further improve the interpretability, usability, and scalability of the GADGET package for both academic and applied use cases.
The following figure shows the structure and flow of core functions and objects in the GADGET package.